package com.hzchendou.blog.demo.interceptor;

import com.hzchendou.blog.demo.params.PageVO;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.CallableStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.StatementType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;

/**
 * 分页拦截器
 *
 * @Date: 2022-07-08 15:08
 * @since: 1.0
 */
@Slf4j
@Data
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class,
        Integer.class})})
public class PaginationInterceptor implements Interceptor {

    public static final String DELEGATE_BOUNDSQL_SQL = "delegate.boundSql.sql";
    public static final String DELEGATE_BOUNDSQL = "delegate.boundSql";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        log.info("拦截StatementHandler.prepare方法, {}", invocation);
        /// 得到代理目标对象
        StatementHandler statementHandler = realTarget(invocation.getTarget());
        /// MetaObject主要用于设置或者是获取对象属性
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // 先判断是不是SELECT操作  (跳过存储过程和非Select操作)
        MappedStatement mappedStatement = (MappedStatement) metaObject
                .getValue("delegate.mappedStatement");
        if (SqlCommandType.SELECT != mappedStatement.getSqlCommandType()
                || StatementType.CALLABLE == mappedStatement.getStatementType()) {
            return invocation.proceed();
        }
        statementHandler = metaObject.hasGetter("delegate") ? (StatementHandler) metaObject
                .getValue("delegate") : statementHandler;
        if (!(statementHandler instanceof CallableStatementHandler)) {
            // 标记是否修改过 SQL
            BoundSql boundSql = (BoundSql) metaObject.getValue(DELEGATE_BOUNDSQL);
            Object parameterObejct = boundSql.getParameterObject();
            String sql = boundSql.getSql();
            /// 查看是否为分页参数
            if (parameterObejct != null && parameterObejct instanceof PageVO) {
                PageVO param = (PageVO) parameterObejct;
                int offset = (param.getPage() - 1) * param.getSize();
                int size = param.getSize();
                sql = sql + " limit " + offset + "," + size;
                metaObject.setValue(DELEGATE_BOUNDSQL_SQL, sql);
                log.info("完成分页SQL配置, {}", sql);
                //// 查询总数，并将总记录数保存到PageVO对象中
                Connection connection = (Connection)invocation.getArgs()[0];
                countTotal(param, parameterObejct, statementHandler, connection, mappedStatement);
                log.info("完成分页总记录数查询, {}", param);
            }
        }
        return invocation.proceed();
    }


    /**
     * 获得真正的处理对象,可能多层代理.
     */
    @SuppressWarnings("unchecked")
    private static <T> T realTarget(Object target) {
        if (Proxy.isProxyClass(target.getClass())) {
            MetaObject metaObject = SystemMetaObject.forObject(target);
            return realTarget(metaObject.getValue("h.target"));
        }
        return (T) target;
    }


    /**
     * #计算总记录和总分页数
     *
     * @param pageVo
     * @param parameterObject
     * @param connection
     */
    private void countTotal(PageVO pageVo, Object parameterObject,
            StatementHandler statementHandler, Connection connection, MappedStatement mappedStatement) {
        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql().toLowerCase();
        // 获取统计SQL
        int startIndex = sql.indexOf("from");
        sql = sql.substring(startIndex);
        sql = "select count(1) " + sql;

        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), sql,
                boundSql.getParameterMappings(), parameterObject);

        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement,
                parameterObject, countBoundSql);
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(sql);
            parameterHandler.setParameters(pstmt);
            rs = pstmt.executeQuery();
            if (rs.next()) {
                long totalRecord = rs.getLong(1);
                pageVo.setTotal(totalRecord);  // 总记录数
            }
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                if (rs != null) {
                    rs.close();
                }
                if (pstmt != null) {
                    pstmt.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}
